#include "asan_internal_defs.h"
#include "asan_interceptors.h"
#include "asan_alloctor.h"
#include "asan_report.h"
#include "sanitizer_libc.h"
#include "interception.h"
#include "asan_utils.h"
#include "sanitizer_printf.h"
#include <wchar.h>
#include <stdarg.h>
#include <stdio.h>

using namespace __sanitizer;

// Emits a trace line of the form "<file>:<line> <function>" through VReport
// at verbosity ASAN_LOG_DEBUG. The expansion already ends in ';', so call
// sites write INTERCEPTOR_VREPORT on its own line with no trailing semicolon.
#define INTERCEPTOR_VREPORT VReport(ASAN_LOG_DEBUG, "%s:%d %s\n", __FILE__, __LINE__, __func__);

INTERCEPTOR(void, free, void *ptr) {
  INTERCEPTOR_VREPORT
  if (asan_inited) {
    VReport(ASAN_LOG_DEBUG, "free %p\n", ptr);
    allocatorPtr->DoFree(ptr);
    REAL(free(ptr));
  }
  // All memory allocated before __asan_init() will not be freed
  // Anyway, it is mere 1 page (see ASAN_PAGE_SIZE)
}

// Intercepted malloc(): allocates with the real libc malloc and records the
// block with the ASan allocator so later accesses to it can be checked.
//
// Before __asan_init() has completed, requests are served by
// internal_malloc() instead (presumably for allocations made while the
// runtime itself is being set up -- confirm).
// A failed real allocation (nullptr) is returned to the caller without being
// recorded, so the allocator never tracks a chunk at address 0.
INTERCEPTOR(void*, malloc, uptr size) {
  INTERCEPTOR_VREPORT
  if (!asan_inited) {
    return internal_malloc(size);
  }
  void* addr = REAL(malloc(size));
  VReport(ASAN_LOG_DEBUG, "malloc %p %llu\n", addr, size);
  if (addr != nullptr) {
    allocatorPtr->DoMalloc(addr, size);
  }
  return addr;
}

// Intercepted calloc(): allocates zeroed memory with the real libc calloc and
// records the block (num * size bytes) with the ASan allocator.
//
// Before __asan_init() has completed, requests are served by
// internal_calloc(). Unlike the other interceptors, this check runs before
// INTERCEPTOR_VREPORT. NOTE(review): the ordering looks deliberate (keeping
// the pre-init path minimal) -- confirm before changing it.
INTERCEPTOR(void*, calloc, uptr num, uptr size) {
  if (!asan_inited) {
    return internal_calloc(num, size);
  }
  INTERCEPTOR_VREPORT
  void* addr = REAL(calloc(num, size));
  VReport(ASAN_LOG_DEBUG, "calloc %p %llu %llu\n", addr, num, size);
  // The real calloc returns nullptr on failure, including when num * size
  // is not representable. Recording only successful allocations keeps
  // DoMalloc() from ever seeing a null address or a wrapped-around size.
  if (addr != nullptr) {
    allocatorPtr->DoMalloc(addr, num * size);
  }
  return addr;
}

INTERCEPTOR(void*, realloc, void* ptr, uptr size) {
  INTERCEPTOR_VREPORT
  // we first ensure the ptr is malloc-ed
  if (ptr != nullptr) {
    // we perform checks when the ptr is not NULL
    AsanChunk* old_chunk = allocatorPtr->FindChunkWithBeg(ptr);
    if (old_chunk == nullptr) {
      Printf("realloc with an address not malloc-ed yet.\n");
      Die();
    }
    allocatorPtr->DoQuickFree(old_chunk);
    // we cannot resize the chunk directly
    // as the realloc can return a new address
    void* addr = REAL(realloc(ptr, size));
    if (size > 0) {
      allocatorPtr->DoMalloc(addr, size);
    }
    VReport(ASAN_LOG_DEBUG, "realloc %p %llu\n", addr, size);
    return addr;
  }
  else {
    // behaves like a malloc
    void* addr = REAL(realloc(ptr, size));
    if (size > 0) {
      VReport(ASAN_LOG_DEBUG, "realloc %p %llu\n", addr, size);
      allocatorPtr->DoMalloc(addr, size);
    }
    return addr;
  }
}

INTERCEPTOR(void*, alloca, uptr size) {
  INTERCEPTOR_VREPORT
  void* addr = REAL(alloca(size));
  VReport(ASAN_LOG_DEBUG, "alloca %p %llu\n", addr, size);
  allocatorPtr->DoAlloca(addr, size);
  return addr;
}

INTERCEPTOR(int, memcmp, const void *a1, const void *a2, uptr size) {
  INTERCEPTOR_VREPORT
  CheckAndReport(a1, size, false);
  CheckAndReport(a2, size, false);
  return REAL(memcmp(a1, a2, size));
}

//============================================================
// Intercepted strlen(): measures the string with internal_strlen(), then
// checks that every byte it read -- the characters plus the terminating
// NUL -- is addressable.
INTERCEPTOR(uptr, strlen, const char *s) {
  INTERCEPTOR_VREPORT
  const uptr len = internal_strlen(s);
  const uptr bytes_with_nul = len + 1;
  CheckAndReport(s, bytes_with_nul, false);
  return len;
}

// Intercepted strnlen(): scans at most `maxlen` bytes. If a NUL is found
// inside that limit it is part of the checked range; otherwise exactly the
// scanned bytes (the returned length) are checked.
INTERCEPTOR(uptr, strnlen, const char *s, uptr maxlen) {
  INTERCEPTOR_VREPORT
  const uptr len = internal_strnlen(s, maxlen);
  const bool nul_in_range = len < maxlen;
  CheckAndReport(s, nul_in_range ? len + 1 : len, false);
  return len;
}

INTERCEPTOR(char *, strcat, char *to, const char *from) {
  INTERCEPTOR_VREPORT
  uptr from_length = internal_strlen(from);
  CheckAndReport(from, from_length + 1, false);
  uptr to_length = internal_strlen(to);
  CheckAndReport(to, to_length + from_length + 1, true);
  return REAL(strcat)(to, from);
}

INTERCEPTOR(char *, strncat, char *to, const char *from, uptr size) {
  INTERCEPTOR_VREPORT
  uptr copy_length = internal_strnlen(from, size);
  if (copy_length < size)
    CheckAndReport(from, copy_length + 1, false);
  else
    CheckAndReport(from, copy_length, false);
  uptr to_length = internal_strlen(to);
  CheckAndReport(to, to_length + copy_length + 1, true);
  return REAL(strncat(to, from, size));
}

INTERCEPTOR(char *, strcpy, char *to, const char *from) {
  INTERCEPTOR_VREPORT
  uptr from_length = internal_strlen(from);
  CheckAndReport(from, from_length + 1, false);
  CheckAndReport(to, from_length + 1, true);
  REAL(memcpy(to, from, from_length + 1));
  return to;
  // return REAL(strcpy(to, from));
}

// Intercepted strncpy(): copies at most `size` bytes of `from` into `to`.
//
// Reads: `from` is scanned up to `size` bytes. Its NUL is part of the read
// only when it appears within those bytes.
// Writes: strncpy always stores exactly `size` bytes into `to`. A shorter
// source is padded with NULs, and a longer one is truncated with no
// terminator written. The write check therefore covers exactly `size` bytes;
// strlen(from) + 1 would miss overflows in the padding and falsely flag a
// full-width copy into a `size`-byte buffer.
INTERCEPTOR(char *, strncpy, char *to, const char *from, uptr size) {
  INTERCEPTOR_VREPORT
  uptr from_length = internal_strnlen(from, size);
  if (from_length < size)
    CheckAndReport(from, from_length + 1, false);
  else
    CheckAndReport(from, from_length, false);
  CheckAndReport(to, size, true);
  return REAL(strncpy(to, from, size));
}

INTERCEPTOR(char *, strdup, const char *s) {
  INTERCEPTOR_VREPORT
  uptr length = internal_strlen(s);
  CheckAndReport(s, length + 1, false);
  return REAL(strdup(s));
}

// wchar_t functions
INTERCEPTOR(wchar_t *, wcscpy, wchar_t *dest, const wchar_t *src) {
  INTERCEPTOR_VREPORT
  uptr src_len = internal_wcslen(src);
  uptr byte_len = (src_len + 1) * sizeof(wchar_t);
  CheckAndReport((void*)src, byte_len, false);
  CheckAndReport((void*)dest, byte_len, true);
  return REAL(wcscpy(dest, src));
}

// Intercepted wcsncpy(): copies at most `n` wide characters from `src` to
// `dest`.
//
// Reads: `src` is scanned up to `n` wide characters. Its L'\0' is part of
// the read only when it appears within that limit.
// Writes: like strncpy, wcsncpy always stores exactly `n` wide characters
// into `dest`. It pads with L'\0' after a short source and writes no
// terminator after a long one. The write check therefore covers
// n * sizeof(wchar_t) bytes.
INTERCEPTOR(wchar_t *, wcsncpy, wchar_t *dest, const wchar_t *src, uptr n) {
  INTERCEPTOR_VREPORT
  uptr src_len = internal_wcsnlen(src, n);
  uptr read_len = src_len * sizeof(wchar_t);
  if (src_len < n) {
    read_len += sizeof(wchar_t);
  }
  CheckAndReport((void*)src, read_len, false);
  CheckAndReport((void*)dest, n * sizeof(wchar_t), true);
  return REAL(wcsncpy(dest, src, n));
}

INTERCEPTOR(int, printf, const char *format, ...) {
  INTERCEPTOR_VREPORT
  // We merely check if the format is a valid string pointer with valid length
  // check the pointer at initial
  CheckAndReport((void*)format, 1, false);
  uptr read_len = REAL(strlen(format));
  CheckAndReport((void*)format, (read_len + 1) * sizeof(char), false);
  va_list args;
  va_start(args, format);
  va_list copy_args;
  va_copy(copy_args, args);
  CheckPrintfVars(format, copy_args);
  int ret = vprintf(format, args);
  va_end(copy_args);
  va_end(args);
  return ret;
}

// Intercepted snprintf(): checks the output buffer, the format string and,
// through CheckPrintfVars(), the variadic arguments, then formats with
// vsnprintf().
//
// The output check is deliberately aggressive: all `n` bytes of `s` must be
// writable, even if the formatted text turns out shorter. When n == 0
// nothing is written and C permits `s` to be NULL (the common
// snprintf(NULL, 0, fmt, ...) length query), so the buffer check is skipped
// in that case.
INTERCEPTOR(int, snprintf, char * s, size_t n, const char * format, ...) {
  INTERCEPTOR_VREPORT
  if (n > 0) {
    CheckAndReport((void*)s, n, true);
  }
  // Validate the format pointer first, then its full length.
  CheckAndReport((void*)format, 1, false);
  uptr read_len = REAL(strlen(format));
  CheckAndReport((void*)format, (read_len + 1) * sizeof(char), false);
  va_list args;
  va_start(args, format);
  // CheckPrintfVars consumes its own copy so `args` stays intact for
  // vsnprintf.
  va_list copy_args;
  va_copy(copy_args, args);
  CheckPrintfVars(format, copy_args);
  int ret = vsnprintf(s, n, format, args);
  va_end(copy_args);
  va_end(args);
  return ret;
}

// CXX function
// Intercepted strstr() (C++ const overload): returns the first occurrence of
// `needle` in `haystack`, or nullptr. It uses a straightforward scan costing
// O(len(haystack) * len(needle)).
//
// internal_strlen() reads each string up to and including its NUL, so both
// are checked for length + 1 bytes, as in the strlen/strcpy/strdup
// interceptors. Checking only `length` bytes would let an out-of-bounds
// terminator go unreported.
// An empty needle matches at the start of `haystack`, as in the C library.
CXX_INTERCEPTOR(const char*, strstr, const char *haystack, const char *needle) {
  INTERCEPTOR_VREPORT
  uptr len1 = internal_strlen(haystack);
  CheckAndReport(haystack, len1 + 1, false);
  uptr len2 = internal_strlen(needle);
  CheckAndReport(needle, len2 + 1, false);
  if (len1 < len2) return nullptr;
  for (uptr pos = 0; pos <= len1 - len2; pos++) {
    if (internal_memcmp(haystack + pos, needle, len2) == 0)
      return haystack + pos;
  }
  return nullptr;
}

CXX_INTERCEPTOR(const char*, strchr, const char *str, int c) {
  INTERCEPTOR_VREPORT
  // const char* result = REAL(strchr(str, c));
  const char* result = internal_strchr(str, c);
  if (result != nullptr)
    CheckAndReport(str, (uptr)result - (uptr)str + 1, false);
  return result;
}